from pymongo import MongoClient
from sklearn.model_selection import train_test_split
import numpy as np
import pickle as pkl
from tokenizer import tokenize
from sklearn import preprocessing
import csv
# MongoDB connection; the loop below reads one collection per name in `collist`.
client = MongoClient(port=27017)
db = client['tiktok']

# Maximum number of tokens kept (after truncation) per caption / sticker text.
sequence_len = 20
# sequence_len_sticker = 30

# Running count of documents that pass every filter, across all collections.
c = 0

# Collections processed in earlier runs, kept for reference:
#   ['inkdrawing', 'worldseries', 'smallbusiness', 'lunarnewyear', 'selfimprovement',
#    'productivity', 'comfortfood', 'workingathome']
#   ['familyimpression', 'rnbvibes', 'festivefashion']
#   ['holidayvibes', 'gamenight', 'happyholidays', 'carsoftiktok', 'fallguysmoments',
#    'interiordesign', 'homecooked', 'veteransday', 'youwantmore', 'interiordesign',
#    'coldweather', 'wildanimals', 'mycostume']
#   ['meleaving', 'mypfp', 'catchphrases', 'watchmegrow', 'holidaycrafts', 'growupwithme',
#    'clingypet', 'happyhanukkah', 'lunarnewyear', 'tabletop', 'comfortfood',
#    'selfimprovement', '2021affirmations', 'perfectmatch']
#   ['givingszn', 'holidaycountdown', 'bakingszn', 'holidaymusic', 'familyimpression',
#    'inkdrawing', 'WeekendVibes', 'recordsday']
#   the previous list + ['smallbusiness', 'productivity', 'smallbusiness', 'worldseries',
#    'onlinedating', 'artmas']
#   the union of the holidayvibes.., meleaving.. and givingszn.. lists
#    + ['productivity', 'smallbusiness']
# Collections processed in this run:
collist = [
    'falldiy', 'whenwewereyounger', 'yellow', 'ComingOfAge', 'artmas',
    'gaminglife', 'gamingsetup', 'hellowinter', 'planttiktok', 'housetour',
    'neonshadow', 'homeoffice', 'raisedby', 'makeitvogue', 'foodtiktok',
    'valentinesday', 'yougotthis', 'stemlife',
]

# Map each id in column 1 of the TSV to the smallest integer seen in column 5
# for that id (ids can repeat across rows). Column 5 is presumably a time value
# -- the main loop keeps 400..1500 and saves ids as "ids_4h-15h"; TODO confirm.
tidlist = {}
# newline='' is what the csv module requires for correct line handling.
with open('D:\\Work\\Tool\\tiktok\\all\\infos_d7.tsv', 'r', encoding='utf-8',
          newline='') as filename_output:
    reader = csv.reader(filename_output, delimiter='\t')
    next(reader, None)  # skip the header row; tolerate an empty file
    for line in reader:
        if len(line) < 6:
            continue  # blank/truncated row (csv yields [] for an empty line)
        tid_key, tid_value = line[1], int(line[5])
        if tid_key not in tidlist or tid_value < tidlist[tid_key]:
            tidlist[tid_key] = tid_value
# Word-embedding tables indexed by the ids stored in word_index; row 0 is used
# by the main loop only through its zero padding, not looked up directly.
embedding_matrix = np.load('E:\\data_pi\\embedding_matrix.npy')
embedding_matrix_norm = np.load('E:\\data_pi\\embedding_matrix_norm.npy')
# Close the pickle file deterministically instead of leaking the handle.
# NOTE: pickle.load can execute arbitrary code; only load locally produced files.
with open('E:\\data_pi\\word_index.pkl', 'rb') as word_index_file:
    word_index = pkl.load(word_index_file)
htc = 0  # not used in the visible code; kept for compatibility
#for col in db.list_collection_names():
# Each row of a document's img_embed mapping holds per-frame features in its
# first 4096 columns; the remaining columns are max-pooled over all frames.
_VGG_DIM = 4096
# Frames kept per video; videos with fewer frames are zero-padded to this length.
img_sequence_len = 13
# Editing-feature keys, in the column order written to edit_embed_new_<col>.npy.
_EDIT_KEYS = ('video_len', 'var_sb', 'var_sb_c', 'var_vgg', 'var_vgg_c',
              'var_yamnet', 'sticker_num', 'avg_sticker_length', 'avg_scences')
# Every output is written as <_OUT_DIR><name><collection>.npy.
_OUT_DIR = 'E:\\data_pi\\'


def _has_required_features(obj):
    """Return True when a Mongo document has every feature this script exports.

    Checks short-circuit in a fixed order, so a document with an empty
    img_embed is rejected without touching text_feature.
    """
    vf = obj['video_feature']
    return (len(vf['img_embed']) > 0
            and len(obj['text_feature']['text']) > 0
            and len(vf['audio']['yamnet']) > 0
            and len(vf['label']) > 0
            and len(vf['residual']) > 0
            and 'labelA_n' in vf['label']
            and 'var_sb' in vf['editing']
            and 'avg_sticker_length' in vf['editing']
            and 'avg_scences' in vf['editing'])


def _split_img_embed(frames):
    """Split per-frame embedding rows into (per-frame vectors, pooled vector).

    frames: mapping of frame key -> sequence of floats; key '0' must exist.
    Returns the first _VGG_DIM values of every row (in mapping order) and the
    element-wise maximum of the values after _VGG_DIM across all rows.
    """
    per_frame = [row[:_VGG_DIM] for row in frames.values()]
    pooled = list(frames['0'][_VGG_DIM:])
    for row in frames.values():
        for j, value in enumerate(row[_VGG_DIM:]):
            if pooled[j] < value:
                pooled[j] = value
    return per_frame, pooled


def _encode_text(raw_text):
    """Tokenize raw_text into (ids, embeddings, normalized embeddings).

    Only the first `sequence_len` tokens are considered; tokens missing from
    word_index are dropped. All three lists are zero-padded to `sequence_len`.
    """
    ids, embeds, embeds_norm = [], [], []
    for word in tokenize(raw_text)[:sequence_len]:
        if word not in word_index:
            continue
        idx = word_index[word]
        ids.append(idx)
        embeds.append(embedding_matrix[idx])
        embeds_norm.append(embedding_matrix_norm[idx])
    while len(ids) < sequence_len:
        ids.append(0)
        embeds.append(np.zeros(embedding_matrix.shape[1]))
        embeds_norm.append(np.zeros(embedding_matrix_norm.shape[1]))
    return ids, embeds, embeds_norm


for col in collist:
    idlist = []
    img = {}
    yamnet = {}
    texts = {}
    texts_sticker = {}
    labels = {}
    edits = {}
    htids = {}
    img_embed = {}
    print(col)
    # NOTE(review): find(no_cursor_timeout=True) was tried before; re-enable it
    # if long collections hit the server-side cursor timeout.
    for obj in db[col].find():
        tid = tidlist.get(obj['_id'])
        if tid is None or tid < 400 or tid > 1500:
            continue
        if not _has_required_features(obj):
            continue
        c += 1
        vid = obj['_id']
        vf = obj['video_feature']
        idlist.append(vid)
        img_embed[vid], img[vid] = _split_img_embed(vf['img_embed'])
        yamnet[vid] = vf['audio']['yamnet']
        texts[vid] = obj['text_feature']['text']
        # Every sticker item followed by one space (trailing space included).
        texts_sticker[vid] = ''.join(item + ' ' for item in obj['text_feature']['stickerText'])

        label = vf['label']
        label_a_n = int(label['labelA_n'])
        label_b_n = int(label['labelB_n'])
        labelc = int(label_a_n == 1 and label_b_n == 1)
        labels[vid] = [int(label['labelA']), int(label['labelB']),
                       label_a_n, label_b_n, labelc]

        htids[vid] = col
        edits[vid] = [vf['editing'][key] for key in _EDIT_KEYS]

    train_ids = list(idlist)
    train_img_emb = []
    train_yamnet = []
    train_img_emb_new = []
    train_text, train_text_embed, train_text_embed_norm = [], [], []
    train_text_sticker, train_text_embed_sticker, train_text_embed_norm_sticker = [], [], []
    train_labels = []
    train_edits = []

    for item in idlist:
        train_img_emb.append(img[item])
        train_yamnet.append(yamnet[item])

        text, text_embed, text_embed_norm = _encode_text(texts[item])
        train_text.append(text)
        train_text_embed.append(text_embed)
        train_text_embed_norm.append(text_embed_norm)

        text_sticker, text_embed_sticker, text_embed_norm_sticker = _encode_text(texts_sticker[item])
        train_text_sticker.append(text_sticker)
        train_text_embed_sticker.append(text_embed_sticker)
        train_text_embed_norm_sticker.append(text_embed_norm_sticker)

        frames = list(img_embed[item][:img_sequence_len])
        frames.extend(np.zeros(_VGG_DIM) for _ in range(img_sequence_len - len(frames)))
        train_img_emb_new.append(frames)

        train_labels.append(labels[item])
        train_edits.append(edits[item])

    train_img_emb = np.array(train_img_emb).astype(np.float32)
    # Previously assigned to a misspelled name, so the float64 data was saved.
    train_yamnet = np.array(train_yamnet).astype(np.float32)
    train_img_emb_new = np.array(train_img_emb_new).astype(np.float32)
    train_edits = np.array(train_edits).astype(np.float32)

    # The pooled vector used to be written to image_embed_new_ and immediately
    # overwritten by the per-frame sequence below; give it its own file.
    np.save(_OUT_DIR + 'image_embed_pooled_new_' + col, train_img_emb)
    np.save(_OUT_DIR + 'yamnet_embed_new_' + col, train_yamnet)
    np.save(_OUT_DIR + 'label_new_' + col, np.array(train_labels))
    np.save(_OUT_DIR + 'text_new_' + col, train_text)
    np.save(_OUT_DIR + 'edit_embed_new_' + col, train_edits)
    np.save(_OUT_DIR + 'image_embed_new_' + col, train_img_emb_new)

    np.save(_OUT_DIR + 'ids_4h-15h_new_' + col, train_ids)

    np.save(_OUT_DIR + 'text_embed_new_' + col, train_text_embed)
    np.save(_OUT_DIR + 'text_embed_norm_new_' + col, train_text_embed_norm)

    np.save(_OUT_DIR + 'text_sticker_new_' + col, train_text_sticker)
    np.save(_OUT_DIR + 'text_sticker_embed_new_' + col, train_text_embed_sticker)
    np.save(_OUT_DIR + 'text_sticker_embed_norm_new_' + col, train_text_embed_norm_sticker)
